/**
 * Copyright 2023-2023 Huawei Technologies Co., Ltd
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#include "hiai_ndk_align.h"

#include <map>

namespace hiai {

/**
 * @brief Translates an OH_NN_DataType value into the matching HIAI_DataType.
 *
 * OH_NN_UNKNOWN is treated as the default case and resolves to
 * HIAI_DATATYPE_UINT8, the same result as OH_NN_UINT8.
 *
 * @param dataType Data type to translate.
 * @return The corresponding HIAI_DataType, or HIAI_DATATYPE_NOT_SUPPORTED when
 *         dataType is not one of the values handled below.
 */
HIAI_DataType HIAIAlign::ConvertNNDataTypeToHIAI(OH_NN_DataType dataType)
{
    switch (dataType) {
        case OH_NN_UNKNOWN: // default value
        case OH_NN_UINT8:
            return HIAI_DATATYPE_UINT8;
        case OH_NN_BOOL:
            return HIAI_DATATYPE_BOOL;
        case OH_NN_INT8:
            return HIAI_DATATYPE_INT8;
        case OH_NN_INT16:
            return HIAI_DATATYPE_INT16;
        case OH_NN_INT32:
            return HIAI_DATATYPE_INT32;
        case OH_NN_INT64:
            return HIAI_DATATYPE_INT64;
        case OH_NN_UINT32:
            return HIAI_DATATYPE_UINT32;
        case OH_NN_FLOAT16:
            return HIAI_DATATYPE_FLOAT16;
        case OH_NN_FLOAT32:
            return HIAI_DATATYPE_FLOAT32;
        default:
            // Anything without an explicit mapping is reported as unsupported.
            return HIAI_DATATYPE_NOT_SUPPORTED;
    }
}

/**
 * @brief Translates a HIAI_DataType value into the matching OH_NN_DataType.
 *
 * @param dataType Data type to translate.
 * @return The corresponding OH_NN_DataType, or OH_NN_UNKNOWN when dataType is
 *         not one of the values handled below.
 */
OH_NN_DataType HIAIAlign::ConvertHIAIDataTypeToNN(HIAI_DataType dataType)
{
    switch (dataType) {
        case HIAI_DATATYPE_BOOL:
            return OH_NN_BOOL;
        case HIAI_DATATYPE_INT8:
            return OH_NN_INT8;
        case HIAI_DATATYPE_INT16:
            return OH_NN_INT16;
        case HIAI_DATATYPE_INT32:
            return OH_NN_INT32;
        case HIAI_DATATYPE_INT64:
            return OH_NN_INT64;
        case HIAI_DATATYPE_UINT8:
            return OH_NN_UINT8;
        case HIAI_DATATYPE_UINT32:
            return OH_NN_UINT32;
        case HIAI_DATATYPE_FLOAT16:
            return OH_NN_FLOAT16;
        case HIAI_DATATYPE_FLOAT32:
            return OH_NN_FLOAT32;
        default:
            // No explicit mapping for this value.
            return OH_NN_UNKNOWN;
    }
}

/**
 * @brief Translates an OH_NN_Format value into the matching HIAI_Format.
 *
 * OH_NN_FORMAT_NONE is treated as the default case and resolves to
 * HIAI_FORMAT_NCHW, the same result as OH_NN_FORMAT_NCHW.
 *
 * @param format Format to translate.
 * @return The corresponding HIAI_Format, or HIAI_FORMAT_NOT_SUPPORT when
 *         format is not one of the values handled below.
 */
HIAI_Format HIAIAlign::ConvertNNFormatToHIAI(OH_NN_Format format)
{
    switch (format) {
        case OH_NN_FORMAT_NONE: // default value
        case OH_NN_FORMAT_NCHW:
            return HIAI_FORMAT_NCHW;
        case OH_NN_FORMAT_NHWC:
            return HIAI_FORMAT_NHWC;
        case OH_NN_FORMAT_ND:
            return HIAI_FORMAT_ND;
        default:
            // Anything without an explicit mapping is reported as unsupported.
            return HIAI_FORMAT_NOT_SUPPORT;
    }
}

/**
 * @brief Translates a HIAI_Format value into the matching OH_NN_Format.
 *
 * @param format Format to translate.
 * @return The corresponding OH_NN_Format, or OH_NN_FORMAT_NONE when format is
 *         not one of the values handled below.
 */
OH_NN_Format HIAIAlign::ConvertHIAIFormatToNN(HIAI_Format format)
{
    switch (format) {
        case HIAI_FORMAT_NCHW:
            return OH_NN_FORMAT_NCHW;
        case HIAI_FORMAT_NHWC:
            return OH_NN_FORMAT_NHWC;
        case HIAI_FORMAT_ND:
            return OH_NN_FORMAT_ND;
        default:
            // No explicit mapping for this value.
            return OH_NN_FORMAT_NONE;
    }
}

/**
 * @brief Translates a HIAI_ModelPriority value into the matching OH_NN_Priority.
 *
 * Implemented as a switch so that no lookup container is constructed and
 * destroyed on each call.
 *
 * @param priority Model priority to translate.
 * @return The corresponding OH_NN_Priority; OH_NN_PRIORITY_LOW is returned for
 *         any priority that is not one of the values handled below.
 */
OH_NN_Priority HIAIAlign::ConvertHiaiPriorityToNNPriority(HIAI_ModelPriority priority)
{
    switch (priority) {
        case HIAI_PRIORITY_LOW:
            return OH_NN_PRIORITY_LOW;
        case HIAI_PRIORITY_MIDDLE:
            return OH_NN_PRIORITY_MEDIUM;
        case HIAI_PRIORITY_HIGH:
            return OH_NN_PRIORITY_HIGH;
        default:
            return OH_NN_PRIORITY_LOW; // default low
    }
}
}